{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The Bayesian Models\n",
    "\n",
    "One of the most interesting features of the library is access to full Bayesian models, which are used in almost exactly the same way as any of the other models in the library.\n",
    "\n",
    "Note, however, that the Bayesian models are **ONLY** available for tabular data and that, at the moment, they cannot be combined to form a Wide and Deep model.\n",
    "\n",
    "The implementation in this library is based on the publication [Weight Uncertainty in Neural Networks](https://arxiv.org/pdf/1505.05424.pdf), by Blundell et al., 2015. Code-wise, our implementation is inspired by a number of sources:\n",
    "\n",
    "1. https://joshfeldman.net/WeightUncertainty/\n",
    "2. https://www.nitarshan.com/bayes-by-backprop/\n",
    "3. https://github.com/piEsposito/blitz-bayesian-deep-learning\n",
    "4. https://github.com/zackchase/mxnet-the-straight-dope/tree/master/chapter18_variational-methods-and-uncertainty\n",
    "\n",
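    "As a rough illustration of the idea behind that paper (not of this library's internals, which may differ), the sketch below samples weights with the reparameterisation trick and evaluates a scale-mixture prior. The helper names are hypothetical; the default values simply mirror the `prior_sigma_1`, `prior_sigma_2`, `prior_pi`, `posterior_mu_init` and `posterior_rho_init` arguments used later in this notebook, and their mapping onto the formulas is an assumption based on the paper.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "\n",
    "def softplus(rho):\n",
    "    # sigma = log(1 + exp(rho)) keeps the standard deviation positive\n",
    "    return np.log1p(np.exp(rho))\n",
    "\n",
    "\n",
    "def sample_weights(mu, rho, rng):\n",
    "    # reparameterisation: w = mu + sigma * eps, with eps ~ N(0, 1)\n",
    "    eps = rng.standard_normal(mu.shape)\n",
    "    return mu + softplus(rho) * eps\n",
    "\n",
    "\n",
    "def log_normal_pdf(w, sigma):\n",
    "    # log density of a zero-mean Gaussian with standard deviation sigma\n",
    "    return -0.5 * np.log(2 * np.pi * sigma**2) - w**2 / (2 * sigma**2)\n",
    "\n",
    "\n",
    "def log_mixture_prior(w, pi=0.8, sigma_1=1.0, sigma_2=0.002):\n",
    "    # scale mixture of two zero-mean Gaussians, summed over all weights\n",
    "    log_p1 = np.log(pi) + log_normal_pdf(w, sigma_1)\n",
    "    log_p2 = np.log(1 - pi) + log_normal_pdf(w, sigma_2)\n",
    "    return np.logaddexp(log_p1, log_p2).sum()\n",
    "\n",
    "\n",
    "mu = np.zeros(4)  # plays the role of posterior_mu_init=0 (assumption)\n",
    "rho = np.full(4, -7.0)  # plays the role of posterior_rho_init=-7.0 (assumption)\n",
    "w = sample_weights(mu, rho, rng)\n",
    "print(softplus(-7.0))  # roughly 9.1e-4, i.e. a very narrow initial posterior\n",
    "print(log_mixture_prior(w))\n",
    "```\n",
    "\n",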
    "The two Bayesian models available in the library are: \n",
    "\n",
    "1. `BayesianWide`: this is a linear model where the non-linearities are captured via crossed-columns.\n",
    "2. `BayesianTabMlp`: this is a standard MLP that receives categorical embeddings and continuous cols (embedded or not), which are then passed through a series of dense layers. All parameters in the model are probabilistic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "from pytorch_widedeep.metrics import Accuracy\n",
    "from pytorch_widedeep.datasets import load_adult\n",
    "from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint\n",
    "from pytorch_widedeep.preprocessing import TabPreprocessor, WidePreprocessor\n",
    "from pytorch_widedeep.bayesian_models import BayesianWide, BayesianTabMlp\n",
    "from pytorch_widedeep.training.bayesian_trainer import BayesianTrainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first few steps are the familiar ones, the same as for any other model described in the other notebooks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>educational_num</th>\n",
       "      <th>marital_status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>gender</th>\n",
       "      <th>capital_gain</th>\n",
       "      <th>capital_loss</th>\n",
       "      <th>hours_per_week</th>\n",
       "      <th>native_country</th>\n",
       "      <th>age_buckets</th>\n",
       "      <th>income_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>25</td>\n",
       "      <td>Private</td>\n",
       "      <td>226802</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>89814</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Farming-fishing</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>50</td>\n",
       "      <td>United-States</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>28</td>\n",
       "      <td>Local-gov</td>\n",
       "      <td>336951</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Protective-serv</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>44</td>\n",
       "      <td>Private</td>\n",
       "      <td>160323</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>7688</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>18</td>\n",
       "      <td>?</td>\n",
       "      <td>103497</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>?</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  workclass  fnlwgt     education  educational_num      marital_status  \\\n",
       "0   25    Private  226802          11th                7       Never-married   \n",
       "1   38    Private   89814       HS-grad                9  Married-civ-spouse   \n",
       "2   28  Local-gov  336951    Assoc-acdm               12  Married-civ-spouse   \n",
       "3   44    Private  160323  Some-college               10  Married-civ-spouse   \n",
       "4   18          ?  103497  Some-college               10       Never-married   \n",
       "\n",
       "          occupation relationship   race  gender  capital_gain  capital_loss  \\\n",
       "0  Machine-op-inspct    Own-child  Black    Male             0             0   \n",
       "1    Farming-fishing      Husband  White    Male             0             0   \n",
       "2    Protective-serv      Husband  White    Male             0             0   \n",
       "3  Machine-op-inspct      Husband  Black    Male          7688             0   \n",
       "4                  ?    Own-child  White  Female             0             0   \n",
       "\n",
       "   hours_per_week native_country age_buckets  income_label  \n",
       "0              40  United-States           0             0  \n",
       "1              50  United-States           3             0  \n",
       "2              40  United-States           1             1  \n",
       "3              40  United-States           4             1  \n",
       "4              30  United-States           0             0  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = load_adult(as_frame=True)\n",
    "df.columns = [c.replace(\"-\", \"_\") for c in df.columns]\n",
    "df[\"age_buckets\"] = pd.cut(\n",
    "    df.age, bins=[16, 25, 30, 35, 40, 45, 50, 55, 60, 91], labels=np.arange(9)\n",
    ")\n",
    "df[\"income_label\"] = (df[\"income\"].apply(lambda x: \">50K\" in x)).astype(int)\n",
    "df.drop(\"income\", axis=1, inplace=True)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "train, test = train_test_split(df, test_size=0.2, stratify=df.income_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "wide_cols = [\n",
    "    \"age_buckets\",\n",
    "    \"education\",\n",
    "    \"relationship\",\n",
    "    \"workclass\",\n",
    "    \"occupation\",\n",
    "    \"native_country\",\n",
    "    \"gender\",\n",
    "]\n",
    "crossed_cols = [(\"education\", \"occupation\"), (\"native_country\", \"occupation\")]\n",
    "\n",
    "cat_embed_cols = [\n",
    "    \"workclass\",\n",
    "    \"education\",\n",
    "    \"marital_status\",\n",
    "    \"occupation\",\n",
    "    \"relationship\",\n",
    "    \"race\",\n",
    "    \"gender\",\n",
    "    \"capital_gain\",\n",
    "    \"capital_loss\",\n",
    "    \"native_country\",\n",
    "]\n",
    "continuous_cols = [\"age\", \"hours_per_week\"]\n",
    "\n",
    "target = train[\"income_label\"].values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. `BayesianWide`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)\n",
    "X_tab = wide_preprocessor.fit_transform(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BayesianWide(\n",
    "    input_dim=np.unique(X_tab).shape[0],\n",
    "    prior_sigma_1=1.0,\n",
    "    prior_sigma_2=0.002,\n",
    "    prior_pi=0.8,\n",
    "    posterior_mu_init=0,\n",
    "    posterior_rho_init=-7.0,\n",
    "    pred_dim=1,  # here the models are NOT passed to a WideDeep constructor class so the output dim MUST be specified\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = BayesianTrainer(\n",
    "    model,\n",
    "    objective=\"binary\",\n",
    "    optimizer=torch.optim.Adam(model.parameters(), lr=0.01),\n",
    "    metrics=[Accuracy],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|███████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 124.32it/s, loss=163, metrics={'acc': 0.7813}]\n",
      "valid: 100%|███████████████████████████████████████████████████████████████| 31/31 [00:00<00:00, 238.67it/s, loss=141, metrics={'acc': 0.8219}]\n",
      "epoch 2: 100%|███████████████████████████████████████████████████████████| 123/123 [00:00<00:00, 132.81it/s, loss=140, metrics={'acc': 0.8285}]\n",
      "valid: 100%|███████████████████████████████████████████████████████████████| 31/31 [00:00<00:00, 190.16it/s, loss=140, metrics={'acc': 0.8298}]\n"
     ]
    }
   ],
   "source": [
    "trainer.fit(\n",
    "    X_tab=X_tab,\n",
    "    target=target,\n",
    "    val_split=0.2,\n",
    "    n_epochs=2,\n",
    "    batch_size=256,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. `BayesianTabMlp`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/javierrodriguezzaurin/Projects/pytorch-widedeep/pytorch_widedeep/preprocessing/tab_preprocessor.py:358: UserWarning: Continuous columns will not be normalised\n",
      "  warnings.warn(\"Continuous columns will not be normalised\")\n"
     ]
    }
   ],
   "source": [
    "tab_preprocessor = TabPreprocessor(\n",
    "    cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols\n",
    ")\n",
    "X_tab = tab_preprocessor.fit_transform(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BayesianTabMlp(\n",
    "    column_idx=tab_preprocessor.column_idx,\n",
    "    cat_embed_input=tab_preprocessor.cat_embed_input,\n",
    "    continuous_cols=continuous_cols,\n",
    "    #     embed_continuous_method = \"standard\",\n",
    "    #     cont_embed_activation=\"leaky_relu\",\n",
    "    #     cont_embed_dim = 8,\n",
    "    mlp_hidden_dims=[128, 64],\n",
    "    prior_sigma_1=1.0,\n",
    "    prior_sigma_2=0.002,\n",
    "    prior_pi=0.8,\n",
    "    posterior_mu_init=0,\n",
    "    posterior_rho_init=-7.0,\n",
    "    pred_dim=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = BayesianTrainer(\n",
    "    model,\n",
    "    objective=\"binary\",\n",
    "    optimizer=torch.optim.Adam(model.parameters(), lr=0.01),\n",
    "    metrics=[Accuracy],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|███████████████████████████████████████████████████████████| 123/123 [00:04<00:00, 28.74it/s, loss=2e+3, metrics={'acc': 0.8007}]\n",
      "valid: 100%|███████████████████████████████████████████████████████████| 31/31 [00:00<00:00, 136.89it/s, loss=1.75e+3, metrics={'acc': 0.8418}]\n",
      "epoch 2: 100%|████████████████████████████████████████████████████████| 123/123 [00:04<00:00, 29.41it/s, loss=1.73e+3, metrics={'acc': 0.8596}]\n",
      "valid: 100%|███████████████████████████████████████████████████████████| 31/31 [00:00<00:00, 143.87it/s, loss=1.71e+3, metrics={'acc': 0.8569}]\n"
     ]
    }
   ],
   "source": [
    "trainer.fit(\n",
    "    X_tab=X_tab,\n",
    "    target=target,\n",
    "    val_split=0.2,\n",
    "    n_epochs=2,\n",
    "    batch_size=256,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These models are useful beyond the success metrics because they give us a sense of the uncertainty in each prediction. Let's have a look."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_tab_test = tab_preprocessor.transform(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "predict: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:01<00:00, 33.92it/s]\n"
     ]
    }
   ],
   "source": [
    "preds = trainer.predict(X_tab_test, return_samples=True, n_samples=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 9769)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the predictions have shape `(5, 9769)`: one set of predictions for each time the network was sampled and run internally (the number of samples is set by the parameter `n_samples`). The spread across these samples gives us an idea of how certain the model is about a given prediction.\n",
    "\n",
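    "One simple way to summarise sampled class predictions like these is a majority vote per row, together with the share of samples that agree with it. The sketch below uses a small made-up array as a stand-in for `preds`; the variable names are illustrative and not part of the library.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# made-up stand-in for `preds`: 5 sampled sets of class predictions for 4 rows\n",
    "sampled_preds = np.array(\n",
    "    [\n",
    "        [0, 1, 1, 0],\n",
    "        [0, 1, 0, 0],\n",
    "        [0, 1, 1, 0],\n",
    "        [0, 1, 0, 1],\n",
    "        [0, 1, 1, 0],\n",
    "    ]\n",
    ")\n",
    "\n",
    "# fraction of samples voting for class 1, per row\n",
    "frac_positive = sampled_preds.mean(axis=0)\n",
    "# majority vote and the share of samples that agree with it\n",
    "majority = (frac_positive >= 0.5).astype(int)\n",
    "agreement = np.where(majority == 1, frac_positive, 1 - frac_positive)\n",
    "\n",
    "print(majority)\n",
    "print(agreement)\n",
    "```\n",
    "\n",
    "Rows where the agreement is well below 1 are the ones on which the sampled networks disagree.\n",
    "\n",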
    "Similarly, we can obtain the probabilities:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "predict: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:01<00:00, 32.79it/s]\n"
     ]
    }
   ],
   "source": [
    "probs = trainer.predict_proba(X_tab_test, return_samples=True, n_samples=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 9769, 2)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "probs.shape"
   ]
  },
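  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sampled probabilities with this layout, `(n_samples, n_rows, n_classes)`, can be reduced to a predictive mean and a per-row spread. The sketch below uses random numbers as a stand-in for `probs` and assumes that the last axis holds the probabilities of class 0 and class 1, in that order; the variable names are illustrative.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# random stand-in for `probs`: 5 samples, 4 rows, 2 classes\n",
    "p1 = rng.uniform(size=(5, 4))\n",
    "sampled_probs = np.stack([1 - p1, p1], axis=-1)\n",
    "\n",
    "# predictive mean over the samples, shape (4, 2)\n",
    "mean_probs = sampled_probs.mean(axis=0)\n",
    "# spread of the class-1 probability across samples, shape (4,)\n",
    "std_pos = sampled_probs[:, :, 1].std(axis=0)\n",
    "# final class taken from the averaged probabilities\n",
    "final_pred = mean_probs.argmax(axis=1)\n",
    "\n",
    "print(mean_probs.shape, std_pos.shape, final_pred)\n",
    "```\n",
    "\n",
    "A large standard deviation for a row means the sampled networks disagree about that row."
   ]
  },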
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also see how the model performs each time we sample the network:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8559729757395844\n",
      "0.8564847988535162\n",
      "0.8567918927218753\n",
      "0.8562800696079435\n",
      "0.8558706111167981\n"
     ]
    }
   ],
   "source": [
    "for p in preds:\n",
    "    print(accuracy_score(test[\"income_label\"].values, p))  # (y_true, y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
